from __future__ import division, print_function, absolute_import

import numpy as np
import pickle
from matplotlib import pyplot as plt
from matplotlib.ticker import LogLocator, MultipleLocator
from CDTools.tools import plotting

#
# This section analyzes the results of the numerical experiment designed
# to test the achievable resolution of our algorithm
#

# This will keep track of the overall error distribution across
# all the experiments
errors = []

plt.figure(figsize=(3.5,2.5))

for zp_maxk in [100, 150, 200]:

    with open('data/resolution trial ' + str(zp_maxk) + ' 1 1000.pickle', 'rb') as f:
        results = pickle.load(f)

    fraction_failed = []

    # First we have a set of diffraction patterns at each resolution.
    # For each diffraction pattern there is an ensemble of reconstructions.
    # Only the first entry of each record is used here (presumably the
    # per-reconstruction errors for that pattern -- TODO confirm).
    for res_errors in results['errors']:
        # We start by recording the total error distribution
        errors.extend(res_errors[0])
        # And then we calculate the fraction which have failed to converge,
        # i.e. whose error is still above 0.001
        fraction_failed.append(np.mean(np.array(res_errors[0]) > 0.001))

    # And we plot the results against the resolution ratio.
    # 'resolutions' is storing the size of the object array in pixels,
    # which is twice k_0. Hence the factor of two in the calculation of
    # the resolution ratio
    resolutions = np.array(results['resolutions'])
    plt.plot(resolutions/(2*zp_maxk), fraction_failed, label=str(zp_maxk))

    # Disabled diagnostic: image error vs. diffraction loss for the
    # zp_maxk == 200 trial (uncomment together with the lines below).
    #if zp_maxk == 200:
    #    for i, (res_errors, res_losses) in enumerate(zip(results['errors'],results['losses'])):
    #        if i%3 == 0:
    #            ratio = resolutions[i]/(2*zp_maxk)
    #            plt.loglog(res_losses[0],res_errors[0],'.',label='R=%0.2f'%ratio)
#plt.xlabel('Diffraction Loss')
#plt.ylabel('Image Error')
#plt.legend(loc='lower right')
#plt.show()
#exit()

# We finish and label this plot
plt.ylabel('Fraction Failed')
plt.xlabel('Resolution Ratio R')
plt.legend()
plt.tight_layout()

# Now we plot the overall error distribution: a histogram of log10(error)
# where each bar is the fraction of all recorded reconstructions
plt.figure(figsize=(3.5,2.5))
plt.hist(np.log10(errors), bins=50, weights=np.ones_like(errors)/len(errors))
plt.ylabel('Fraction of Total')
plt.xlabel('Log base 10 of Error')
plt.tight_layout()


#
# This section analyzes the results of the numerical experiment designed
# to test the robustness to leaked spectral weight
#


#with open('data/bandlimiting robustness trial.pickle', 'rb') as f:
#    results = pickle.load(f)

#power_ratios = results['power_ratios']
#median_errors = [np.median(err) for err in results['errors']]

#plt.figure(figsize=(3.5,2.5))
#plt.plot(power_ratios, median_errors)
#plt.xlabel('Power Ratio')
#plt.ylabel('Median RMS Error')
#plt.tight_layout()

#
# This section analyzes the results of the numerical experiment designed
# to test the robustness to Poisson noise
#


with open('data/poisson trial.pickle', 'rb') as f:
    results = pickle.load(f)

# Summarise every photon-count level by the median reconstruction error
# across its trials.
photons_per_pixel = results['photons per pixel']
median_errors = list(map(np.median, results['errors']))

# Log-log plot of median error vs. photon count, drawn on an explicit Axes
plt.figure(figsize=(3.25,2.5))
ax = plt.gca()
ax.loglog(photons_per_pixel, median_errors, linewidth=2)
ax.set_xlabel('Photons Per Pixel')
ax.set_ylabel('Median RMS Error')
# Base-10 major ticks on the x axis (at most ~10 of them)
ax.xaxis.set_major_locator(LogLocator(base=10, numticks=10))
ax.grid()
plt.tight_layout()



#
# This section analyzes the results of the numerical experiment designed
# to test the robustness to unknown probe defocus
#


with open('data/propagation trial.pickle', 'rb') as f:
    results = pickle.load(f)

defocuses = results['defocuses']

# Summarise each defocus value by the median over its ensemble of trials
median_errors = [np.median(err) for err in results['errors']]
median_losses = [np.median(loss) for loss in results['losses']]

# Median image error (left axis, C0) and median diffraction loss
# (right axis, C1) plotted against a shared defocus axis
fig, ax1 = plt.subplots(figsize=(4.15,2.5))
ax1.plot(defocuses, median_errors, 'C0')
ax1.set_xlabel('Defocus (Depths of Focus)')
ax1.set_ylabel('Median RMS Error', color='C0')
ax2 = ax1.twinx()
ax2.plot(defocuses, median_losses,'C1', zorder=2)
ax2.set_ylabel('Median Diffraction Loss',color='C1')
# The twin shares ax1's x ticker and hides its own x axis, so the
# x locator is set on the visible axis
ax1.xaxis.set_major_locator(MultipleLocator(1))
ax2.yaxis.set_major_locator(MultipleLocator(0.03))
fig.tight_layout()



#
# This section plots a complete set of metrics describing a typical
# reconstruction 
#


with open('data/example reconstruction.pickle', 'rb') as f:
    results = pickle.load(f)

image = results['image']

# Rescale the first retrieval so its sum matches the ground-truth image
# (removes the global scale -- and, if the data is complex, phase --
# ambiguity). A new array is built instead of scaling in place, so the
# loaded data is not mutated and a real-valued retrieval can still take
# a complex scale factor.
raw_retrieval = np.asarray(results['retrievals'][0])
retrieval = raw_retrieval * (np.sum(image) / np.sum(raw_retrieval))

# Ground-truth amplitude
plt.figure()
plt.imshow(np.abs(image))
plt.colorbar()
# Magnitude of the error of the rescaled retrieval
plt.figure()
plt.imshow(np.abs(retrieval - image))
plt.colorbar()
# Ground-truth phase (cyclic colormap)
plt.figure()
plt.imshow(np.angle(image),cmap='twilight')
plt.colorbar()
plotting.plot_colorized(results['probe'])

# Diffraction loss vs. iteration, one curve per entry of results['losses']
plt.figure(figsize=(3.5,2.5))
for loss in results['losses']:
    plt.semilogy(loss)
plt.xlabel('Iteration #')
plt.ylabel('Diffraction Loss')
plt.tight_layout()


#
# This section plots the results of the numerical experiments on missing
# data
#


with open('data/missing data trial.pickle', 'rb') as f:
    results = pickle.load(f)


errors = np.array(results['errors'])
resolutions = results['resolutions']
missing_radii = np.array(list(results['missing band radii']))

kp = 128 # later get this from file

# One curve per resolution: median error vs. dead band width, where the
# width is 2 * radius expressed in units of k_p
plt.figure(figsize=(3.25,2.5))
for resolution, errs in zip(resolutions, errors):
    # Axis 0 of errs lines up with missing_radii; axis 1 is presumably the
    # ensemble of trials, which is reduced here.
    plt.semilogy(missing_radii*2/kp, np.median(errs, axis=1),
                 label='R=%0.3f'%(resolution/(2*kp)))
    # Alternative: fraction of trials with error above 0.001 (same x axis)
    #plt.plot(missing_radii*2/kp, np.mean(np.array(errs)>0.001, axis=1), label='R=%0.3f'%(resolution/(2*kp)))

plt.xlabel('Dead Band Width (k_p)')
plt.ylabel('Median RMS Error')
plt.legend()
plt.tight_layout()


with open('data/uniformity trial.pickle', 'rb') as f:
    results = pickle.load(f)

plt.figure(figsize=(3.5,2.5))
# This is at a resolution ratio of 0.6, beamstop ratio of 0.5
plt.imshow(results['intensities'][1][3][0])
cbar = plt.colorbar()
cbar.set_label('Intensity (a.u.)')

# One figure per resolution ratio; one PDF curve per beamstop factor
for resolution_ratio, intensity_stack in zip(results['resolution ratios'],
                                             results['intensities']):

    # Border (in pixels) trimmed from each side of every image
    pad = int(128 * 2 * 0.25 * resolution_ratio)
    # A literal `pad:-pad` is the empty slice `0:0` when pad == 0, so use
    # None as the stop to keep the full image in that case.
    crop = slice(pad, -pad if pad else None)
    plt.figure(figsize=(3.5,2.5))
    for bs_factor, intensity_ims in zip(results['beamstop_factors'],
                                        intensity_stack):
        bs_sixths = int(round(bs_factor*6))
        if bs_sixths == 4:
            continue # Just to make the plot easier to read

        # Pool the cropped pixels of every image, normalised to unit mean
        # (float conversion so the normalisation also works for int data)
        raveled = np.asarray(intensity_ims, dtype=float)[:, crop, crop].ravel()
        raveled = raveled / np.mean(raveled)
        # Empirical PDF over 80 equal-width bins, plotted at bin centres
        hist, bin_edges = np.histogram(raveled, bins=80, density=True)
        xs = (bin_edges[1:] + bin_edges[:-1])/2
        plt.plot(xs, hist, label='BS='+str(bs_sixths)+'/6')

    plt.gca().set_xlim(left=0)
    #plt.title('Pixel Intensity Distribution for R=%0.1f' % resolution_ratio)
    plt.ylabel('PDF (R=%0.1f)'%resolution_ratio)
    plt.xlabel('Intensity')
    plt.legend()
    plt.tight_layout()
    

    

    
with open('data/uniformity trial 2.pickle', 'rb') as f:
    results = pickle.load(f)

# Lower-tail quantile levels to track and their legend labels
# (np.quantile level 0.01 is the 1st percentile, etc.). Append
# (0.00001, 'P=0.001%') to track the 0.001st percentile as well.
quantile_levels = [0.01, 0.001, 0.0001]
quantile_labels = ['P=1%', 'P=0.1%', 'P=0.01%']
# One list per level: that quantile divided by the median, per resolution ratio
quantile_ratios = [[] for _ in quantile_levels]

plt.figure(figsize=(3.5,2.5))
for resolution_ratio, intensity_stack in zip(results['resolution ratios'],
                                             results['intensities']):
    # Only the central region was saved, so no border is cropped here.
    # Pool all pixels and normalise to unit mean (float copy: no in-place
    # division on the loaded data, and integer data is handled).
    raveled = np.asarray(intensity_stack, dtype=float).ravel()
    raveled = raveled / np.mean(raveled)

    q50 = np.quantile(raveled, 0.50)
    for ratios, q in zip(quantile_ratios, np.quantile(raveled, quantile_levels)):
        ratios.append(q / q50)

    # Only a few representative resolution ratios get a histogram curve
    if int(round(resolution_ratio*10)) in [1,3,5,10,20]:
        hist, bin_edges = np.histogram(raveled, bins=100, density=True)
        xs = (bin_edges[1:] + bin_edges[:-1])/2
        plt.plot(xs, hist, label='R=%0.1f'%resolution_ratio)

plt.gca().set_xlim(left=0)
plt.ylabel('PDF (BS=0.5)')
plt.xlabel('Intensity')
plt.legend()
plt.tight_layout()

# Tail quantiles (relative to the median) vs. resolution ratio
plt.figure(figsize=(3.5,2.5))
for ratios, label in zip(quantile_ratios, quantile_labels):
    plt.plot(results['resolution ratios'], ratios, label=label)
plt.xlabel('Resolution Ratio R')
plt.ylabel('Pth Percentile / Median')
plt.ylim([0,1])
plt.legend()
plt.tight_layout()




with open('data/bandlimiting usefulness trial.pickle', 'rb') as f:
    results = pickle.load(f)


errors = np.array(results['errors'])
resolutions = results['resolutions']
objs = results['obj']
probe_size = results['probe'].shape[0]

# Ratio of each object array's first dimension to the probe's
Rs = [obj.shape[0]/probe_size for obj in objs]


plt.figure(figsize=(3.25,2.5))
for R, resolution, errs in zip(Rs, resolutions, errors):
    # Fraction of trials with error above 0.001, reduced over axis 1
    # (presumably the ensemble of trials -- TODO confirm)
    fraction_failed = np.mean(np.asarray(errs) > 0.001, axis=1)

    # Alternative: median error instead of fraction failed
    #plt.semilogy(resolution, np.median(errs, axis=1),label='R=%0.3f'%R)
    plt.plot(np.asarray(resolution)/probe_size, fraction_failed, label='R=%0.3f'%R)

plt.xlabel(r'Resolution Ratio $R_{rec}$')
#plt.ylabel('Median RMS Error')
plt.ylabel('Fraction Failed')
plt.legend()
plt.tight_layout()


# Display every figure created by the sections above
plt.show()
